[water] Support variadic reduction ops in Water dialect and add corresponding simplification pass #1053
Signed-off-by: Martin Lücke <martin.luecke@amd.com>
/// %1 = wave.sum %b init(%0) <scope>
/// %r = wave.sum %c init(%1) <scope>
template <typename ReductionOp>
struct ExpandVariadicReduction : public OpRewritePattern<ReductionOp> {
We have a trait for reductions; would it make sense to make this an OpTraitRewritePattern? I'm very open to arguments here, since traits don't provide named accessors. Related discussion: #992 (comment).
Good idea in principle, but, as you said, traits don't give us named accessors or the typed create() builder. With only two reduction ops and the template giving us full type safety, I think the explicit instantiation is the better tradeoff here. The cost is that we have to remember to add any new reduction op to the patterns.add call.
As mentioned in the related discussion, we also have the option to model this as an interface. I don't have a strong opinion here.
supports variadic inputs for faithful roundtripping with the Python
representation. This pass normalizes them before lowering, which
requires single-input reductions.
}];
Could you document the differences between this pass and its Python counterpart?
Extends the Water dialect reduction ops (wave.sum, wave.max_element) to accept variadic inputs, matching the PyWave representation, where expand_graph tiles reduction inputs into a list of slices. This simplifies FX <-> MLIR roundtrips by letting the dialect represent the intermediate form directly. Otherwise the Python side would have to decompose reductions before emission, track which reductions came from that decomposition, and fuse them again for the roundtrip.
A new ExpandVariadicReductions pass chains N variadic inputs into N single-input reductions, each feeding its result to the next one as the accumulator. This is a partial port of the logic in PyWave's decompose_reduce_ops pass. The Water emitter and the FX importer have both been updated to handle the variadic form.
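To make the chaining concrete, here is a minimal, MLIR-free sketch of the expansion. It is only a toy model under stated assumptions: `SingleReduction` and `expandVariadic` are hypothetical names made up for illustration, values are plain strings rather than SSA values, and nothing here reflects the actual pattern implementation in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for one single-input reduction op.
// This is NOT the Water dialect's real op class.
struct SingleReduction {
  std::string result; // name of the value this reduction produces
  std::string input;  // the single input being reduced
  std::string init;   // accumulator fed into this reduction
};

// Chain N inputs into N single-input reductions. The first reduction uses the
// original init; every later one uses the previous result as its accumulator.
// The last reduction produces `finalResult`, standing in for the result of the
// original variadic op.
std::vector<SingleReduction>
expandVariadic(const std::vector<std::string> &inputs, const std::string &init,
               const std::string &finalResult) {
  std::vector<SingleReduction> chain;
  std::string acc = init;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    std::string res;
    if (i + 1 == inputs.size())
      res = finalResult;              // last link replaces the original op
    else
      res = "%" + std::to_string(i);  // fresh intermediate value: %0, %1, ...
    chain.push_back({res, inputs[i], acc});
    acc = res;                        // feed this result to the next link
  }
  return chain;
}
```

Calling `expandVariadic({"%a", "%b", "%c"}, "%init", "%r")` in this model yields three links whose accumulators are `%init`, `%0`, and `%1` in that order, which mirrors the shape shown in the pattern's doc comment (`%r = wave.sum %c init(%1)`).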
A normal-form annotation for expanded reductions could be added to indicate where in the pipeline single-input reductions are expected, though currently this would only be relevant for codegen, I think.